#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 17 11:44:57 2024

@author: jcfq2
"""

import os

# Root of the local JWST data tree. The working directory is switched to it
# before the remaining imports run and before any relative-path file access.
jwst_dir = '/Users/jcfq2/data/observations/jwst'
os.chdir(jwst_dir)

import matplotlib.pyplot as plt
from astropy.visualization import make_lupton_rgb
from astropy.table import Table
import ch4_fiddlesticks as ch4
import pandas as pd
import h3ppy
import glob
from scipy.optimize import curve_fit
import spiceypy as spice
from astropy.io import fits
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
from scipy.interpolate import interp1d

# Re-import the pointing module so that edits made to it since the first
# import take effect.
reload(jssp)
#import jwst_uranus as jwstu

# In[0]

# Load the kernels from the local kernel directory.
kernel_dir = '/Users/jcfq2/data/observations/jwst/kernels/'
jssp.load_kernels(kdir=kernel_dir)

# Set up a h3ppy object too, always useful
h3p_model = h3ppy.h3p()

# In[1]

# Dither-combined s3d cube used for the first look at the geometry.
file_dir = '/Users/jcfq2/data/observations/jwst/5308_dither_combined/'
fn = 'jw05308001001_03119_g395h-f290lp_s3d.fits'
file = f'{file_dir}{fn}'

# Create a JWSTSolarSystemPointing object that helps with lots of things JWST
geo = jssp.JWSTSolarSystemPointing(file)

# Calculate the geometry for the full IFU cube. This is a three dimensional
# array.
cube = geo.full_fov()

# Get the wavelength scale for this observation - should be the same for all
# NIRSpec G395H/F290LP observations.
wave = geo.get_wavelength()

# These are the available geometric parameters - if you need something that's
# not here, let me know!
# Only using Pandas to make this table pretty, generally not a fan.
# NOTE: as a bare expression the table is only displayed when the cell is run
# interactively; nothing is stored.
pd.DataFrame(geo.keys, columns=["Parameter"])



# %% make an intensity map using the PSG subtraction

import ch4_fiddlesticks as ch4

# Every FITS file in the dither-separated directory, in sorted filename order.
files = glob.glob(
    "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits")
files.sort()

# Pointing object for the first exposure; per the author this is just to get
# geo for conversion.
geo = jssp.JWSTSolarSystemPointing(files[0])
wave = geo.get_wavelength()

# Pre-computed stacks loaded from the working directory. Further down the
# spectra are formed as nam / nac, so nam presumably holds summed spectra and
# nac the number of contributing samples - TODO confirm.
nam = np.load('saturn_spectra_map_2xx.npy')
# namlos=np.load('saturn_spectra_los_shell_2.npy')
nac = np.load('saturn_spectra_count_2xx.npy')

# Multiplier applied to the hard-coded pixel offsets used in the later cells
# (presumably map bins per degree, matching the "2xx" file names - TODO
# confirm).
scale = 2




# %%
# for ww in range(len(nam[0,0,:])):
    # nam[:,:,ww]=nam[:,:,ww]*namlos
# nam=nam*namlos
# namlos=0.0




# Width, in rows of the first axis, of the strip at each end of the stacks
# that is folded back in and then discarded.
edge = 45 * scale

# Fold the first `edge` rows onto the rows that sit between 2*edge and edge
# from the end, and the last `edge` rows onto the rows between edge and
# 2*edge from the start. Sums (nam) and counts (nac) are folded identically so
# that their ratio stays consistent.
nam[-2 * edge:-edge, :] += nam[:edge, :]
nac[-2 * edge:-edge, :] += nac[:edge, :]
nam[edge:2 * edge, :] += nam[-edge:, :]
nac[edge:2 * edge, :] += nac[-edge:, :]

# plt.imshow(nam[:,:,600]/nac[:,:,600])
# plt.show()

# Drop the outer strips now that their content has been folded in.
nam = nam[edge:-edge, :]
nac = nac[edge:-edge, :]

# Ratio of the two stacks. Bins with a zero count divide by zero, which is
# expected here, so the warnings are silenced and every non-finite result is
# mapped to 0. A bare nan_to_num would turn +/-inf into ~1.8e308 instead, and
# such a pixel would no longer look empty to a later "sum == 0" check.
with np.errstate(divide='ignore', invalid='ignore'):
    spec_cube = nam / nac
# copy=False: spec_cube is a fresh array, so it can be cleaned in place.
spec_cube = np.nan_to_num(spec_cube, copy=False,
                          nan=0.0, posinf=0.0, neginf=0.0)

# %%

# Quick-look image of the count stack: spectral index 500, second axis from
# index 220 onwards, rotated by 90 degrees for display.
fig = plt.figure(dpi=300, figsize=(12, 8))
count_slice = np.rot90(nac[:, 220:, 500])
numcount = plt.imshow(count_slice, cmap='tab20c')

# Colour bar underneath the image.
cbar = fig.colorbar(numcount, aspect=4, location='bottom')

plt.show()



# %%


# nam=0.0
# nac=0.0

# Second-to-last file of the sorted exposure list.
file = files[-2]
# Create a JWSTSolarSystemPointing object that helps with lots of things JWST
geo = jssp.JWSTSolarSystemPointing(file)

# Calculate the geometry for the full IFU cube. This is a three dimensional
# array.
cube = geo.full_fov()

# Get the wavelength scale for this observation - should be the same for all
# NIRSpec G395H/F290LP observations.
wave = geo.get_wavelength()


def _window_indices(lower, upper):
    """Return the indices of ``wave`` lying strictly between the two bounds."""
    return np.argwhere((wave > lower) & (wave < upper)).flatten()


# Wavelength windows. Each one keeps its bounds (wavemin_* / wavemax_*) and
# the matching indices into `wave` (whw_*) as module-level names.

# Author's note: this didn't work with PSG, and is outside the range used by
# CH4.
wavemin_331, wavemax_331 = 3.31, 3.34
whw_331 = _window_indices(wavemin_331, wavemax_331)

wavemin_345, wavemax_345 = 3.445, 3.465
whw_345 = _window_indices(wavemin_345, wavemax_345)

wavemin_353, wavemax_353 = 3.5255, 3.565
whw_353 = _window_indices(wavemin_353, wavemax_353)

wavemin_360, wavemax_360 = 3.61, 3.635
whw_360 = _window_indices(wavemin_360, wavemax_360)

wavemin_366, wavemax_366 = 3.662, 3.675
whw_366 = _window_indices(wavemin_366, wavemax_366)

# Labelled "correction for 3.9008" by the author.
# NOTE(review): the window is 3.709-3.721, so the label looks copy-pasted
# from the next window - confirm.
wavemin_371, wavemax_371 = 3.709, 3.721
whw_371 = _window_indices(wavemin_371, wavemax_371)

# Labelled "correction for 3.9008" by the author.
wavemin_390, wavemax_390 = 3.898, 3.911
whw_390 = _window_indices(wavemin_390, wavemax_390)

# The remaining windows were all labelled "correction for 3953" by the
# author. NOTE(review): the label is probably only accurate for the
# 3.93-3.958 window - confirm.
wavemin_392, wavemax_392 = 3.923, 3.9313
whw_392 = _window_indices(wavemin_392, wavemax_392)

wavemin_3953, wavemax_3953 = 3.93, 3.958
whw_3953 = _window_indices(wavemin_3953, wavemax_3953)

# NOTE: the upper bound of this window is stored as wavemax_398 and its
# indices as whw_3966; the names are kept as they are.
wavemin_396, wavemax_398 = 3.961, 3.979
whw_3966 = _window_indices(wavemin_396, wavemax_398)

# NOTE: the upper bound of this window is stored as wavemax_399.
wavemin_398, wavemax_399 = 3.98, 3.991
whw_398 = _window_indices(wavemin_398, wavemax_399)

# to fit temperature

# wavemin = 3.3
# wavemax = 4.3

# Window named after F323 (the filter curve file read later is F323N).
wavemin_F323, wavemax_F323 = 3.148, 3.35
whw_F323 = _window_indices(wavemin_F323, wavemax_F323)

# Two-column, space-delimited filter table; the first column is used as the
# abscissa and the second as the ordinate of a linear interpolator.
F323_values = np.loadtxt('F323N_filter.txt', delimiter=' ')
f323_x = F323_values[:, 0]
f323_y = F323_values[:, 1]
f_F323 = interp1d(f323_x, f323_y, kind='linear')  # other kinds, e.g. 'cubic', are possible


# whw = np.argwhere((wave > wavemin) & (wave < wavemax)).flatten()

# Reference H3+ model evaluated on the observation's wavelength grid.
h3p = h3ppy.h3p()
h3p.set(T=500, N=6e14, wave=wave, R=4700)
h3p_model = h3p.model()

# Median image of the pointing object's cube over the whw_3953 window.
im_h3p = np.nanmedian(geo.im[whw_3953, :, :], axis=0)

# Keep one spectral plane of the stack as a shape/dtype template. It must be
# an independent copy: a plain slice is a view that keeps the whole stack
# alive, so rebinding `nam` would not release its memory.
img = nam[:, :, 1000].copy()
nam = 0.0

# Empty output maps with the same shape and dtype as the template plane.
ch4_fun = np.zeros_like(img)
ch4_hot = np.zeros_like(img)

# `intensity` is left as a scalar placeholder, as before.
intensity = 0.0

# np.save('saturn_h3p_intensity.npy',intensity)

# Exclusive upper bounds of the pixel loops.
xs = 45*scale+1
ys = 360*scale+1





# %%


# Resume from the maps written by earlier runs.
ch4_fun = np.load('saturn_saves/saturn_ch4fun_2xx.npy')
ch4_hot = np.load('saturn_saves/saturn_ch4hot_2xx.npy')

# cheats in the earlier fits: the 331 window is redefined with wider bounds
wavemin_331 = 3.27
wavemax_331 = 3.47
whw_331 = np.argwhere((wave > wavemin_331) & (wave < wavemax_331)).flatten()

# cheats in the earlier fits
wavemin_4 = 3.97
wavemax_4 = 3.99
whw_4 = np.argwhere((wave > wavemin_4) & (wave < wavemax_4)).flatten()

# Snapshot of the hot-band map as loaded. A pixel that is non-zero here was
# fitted in a previous run and is skipped. This is a copy rather than an
# alias, so values written during this run cannot change that decision.
old_temp = ch4_hot.copy()

# degrees above 45: 30 = 75
startx = 16*scale

# Loop invariants, set up once instead of once per pixel: pick up the latest
# edits to the fitting module and build the combined wavelength selection.
reload(ch4)
whw_sub = np.hstack([whw_331, whw_353, whw_360, whw_371,
                     whw_390, whw_392, whw_3953])
wave_sub = wave[whw_sub]

for xxx in range(startx, xs):
    for yyy in range(ys):
        xx = xxx + 135*scale  # (left/right) index into spec_cube
        yy = yyy  # (up/down)
        xx_x = xx - 135*scale  # same column without the offset, for the maps

        # Already fitted in an earlier run.
        if old_temp[yy, xx_x] != 0:
            print(xx/scale, yy/scale, 'pre_fit')
            continue

        spec = geo.convert(wave, spec_cube[yy, xx, :])  # auroral

        # Nothing to fit in this pixel.
        if np.sum(spec) == 0:
            print(xx/scale, yy/scale)
            continue

        print(xxx, yyy, xx, xx_x)

        ch4fit = ch4.fit_non_LTE_CH4_JWST(wave)
        fit = ch4fit.fit(wave, spec)

        # Scaled model components restricted to the selected windows; their
        # sums are what gets stored in the output maps.
        fun_model = ch4fit.ch4_fun[whw_sub]*ch4fit.fun_scaling
        hot_model = ch4fit.ch4_hot[whw_sub]*ch4fit.hot_scaling
        ch4_fun[yy, xx_x] = np.nansum(fun_model)
        ch4_hot[yy, xx_x] = np.nansum(hot_model)

        # Diagnostic figure: spectrum with both model components on the left,
        # the running sum of the two maps on the right.
        fig, (ax1, ax2) = plt.subplots(
            1, 2, figsize=(12, 4), gridspec_kw={'width_ratios': [6, 1]})

        ax1.plot(wave_sub, spec[whw_sub], label='raw spec')
        ax1.plot(wave_sub, fun_model, label='methane_fund',
                 linestyle='dotted')
        ax1.plot(wave_sub, hot_model, label='methane_hot',
                 linestyle='dotted')
        # ax1.plot(ch4fit.wave[whw_sub],(ch4fit.ch4_fun[whw_sub]+ch4fit.ch4_hot[whw_sub])*np.max(spec[whw_sub]),label='methane')
        ax1.legend()  # the labels above are otherwise never shown

        ax2.imshow(ch4_fun+ch4_hot, vmin=0, vmax=0.001)

        plt.show()
        # One figure is opened per pixel; close it so that figures do not
        # pile up in memory over the course of the run.
        plt.close(fig)

np.save('saturn_saves/saturn_ch4fun_2xx.npy', ch4_fun)
np.save('saturn_saves/saturn_ch4hot_2xx.npy', ch4_hot)



